# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch
from tensordict import NestedKey, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModuleBase
from tensordict.utils import expand_right
from torch import nn


def _get_reward(
    gamma: float,
    reward: torch.Tensor,
    done: torch.Tensor,
    max_steps: int,
):
    """Sums the rewards up to max_steps in the future with a gamma decay.

    Supports multiple consecutive trajectories: sums never cross a trajectory
    boundary, which is delimited by the non-zero entries of ``done``.

    Assumes that the time dimension is the *last* dim of ``done``. ``reward``
    must share ``done``'s leading dimensions; it may carry extra trailing
    dimensions, in which case its time dimension sits at position
    ``done.ndim - 1``.

    Args:
        gamma: decay applied to each successive future reward (``gamma**i``
            for the reward ``i`` steps ahead).
        reward: reward tensor (see shape requirements above).
        done: done tensor. A non-zero entry at step ``t`` marks ``t`` as the
            last step of its trajectory.
        max_steps: number of *future* rewards added to the current one, i.e.
            at most ``max_steps + 1`` rewards are summed.

    Returns:
        A ``(summed_rewards, time_to_obs)`` tuple:

        - ``summed_rewards`` has the shape of ``reward`` and holds, for each
          step ``t``, ``sum_i gamma**i * reward[t + i]`` for
          ``i in [0, max_steps]``. The sum is truncated at the end of the
          trajectory (or of the time dimension).
        - ``time_to_obs`` has the shape of ``done`` and holds how many *extra*
          steps (``0..max_steps``) separate the regular "next" observation
          from the one that should replace it.
    """
    # Discount kernel [1, gamma, gamma**2, ..., gamma**max_steps] shaped as a
    # (out_channels=1, in_channels=1, width) conv1d weight.
    filt = torch.tensor(
        [gamma**i for i in range(max_steps + 1)],
        device=reward.device,
        dtype=reward.dtype,
    ).view(1, 1, -1)
    # make one done mask per trajectory:
    # the cumulative sum of done, shifted right by one step, gives each step the
    # index of the trajectory it belongs to (a done step still belongs to the
    # trajectory it terminates).
    done_cumsum = done.cumsum(-1)
    done_cumsum = torch.cat(
        [torch.zeros_like(done_cumsum[..., :1]), done_cumsum[..., :-1]], -1
    )
    # Global number of trajectories. Rows with fewer trajectories simply get
    # all-False masks for the missing ids. Note: ``.item()`` forces a host sync.
    num_traj = done_cumsum.max().item() + 1
    done_cumsum = done_cumsum.expand(num_traj, *done.shape)
    # traj_ids[k] is a boolean mask (shape == done.shape) selecting trajectory k.
    traj_ids = done_cumsum == torch.arange(
        num_traj, device=done.device, dtype=done_cumsum.dtype
    ).view(num_traj, *[1 for _ in range(done_cumsum.ndim - 1)])
    # an expanded reward tensor where each index along dim 0 is a different trajectory
    # Note: rewards could have a different shape than done (e.g. multi-agent with a single
    # done per group).
    # we assume that reward has the same leading dimension as done.
    if reward.shape != traj_ids.shape[1:]:
        # We'll expand the ids on the right first
        traj_ids_expand = expand_right(traj_ids, (num_traj, *reward.shape))
        reward_traj = traj_ids_expand * reward
        # we must make sure that the last dimension of the reward is the time
        reward_traj = reward_traj.transpose(-1, traj_ids.ndim - 1)
    else:
        # simpler use case: reward shape and traj_ids match
        reward_traj = traj_ids * reward

    # Right-pad time with max_steps zeros so the unpadded ("valid") conv1d below
    # returns one value per original step. torch's conv1d is a
    # cross-correlation, so output[t] = sum_i filt[i] * reward[t + i].
    # The per-trajectory masking above keeps other trajectories' rewards out.
    reward_traj = torch.nn.functional.pad(reward_traj, [0, max_steps], value=0.0)
    shape = reward_traj.shape[:-1]
    # conv1d wants (N, C=1, L): merge every non-time dim into the batch dim.
    if len(shape) > 1:
        reward_traj = reward_traj.flatten(0, reward_traj.ndim - 2)
    reward_traj = reward_traj.unsqueeze(-2)
    summed_rewards = torch.conv1d(reward_traj, filt)
    summed_rewards = summed_rewards.squeeze(-2)
    if len(shape) > 1:
        summed_rewards = summed_rewards.unflatten(0, shape)
    # let's check that our summed rewards have the right size
    # Undo the time transpose (if any), keep each trajectory's values only on its
    # own steps, then collapse the trajectory dimension.
    if reward.shape != traj_ids.shape[1:]:
        summed_rewards = summed_rewards.transpose(-1, traj_ids.ndim - 1)
        summed_rewards = (summed_rewards * traj_ids_expand).sum(0)
    else:
        summed_rewards = (summed_rewards * traj_ids).sum(0)

    # time_to_obs is the tensor of the time delta to the next obs
    # 0 = take the next obs (ie do nothing)
    # 1 = take the obs after the next
    # A reverse cumsum of each trajectory mask counts the remaining steps of
    # that trajectory (current step included). This works because trajectories
    # are contiguous along time. The count is capped at max_steps + 1, then
    # masked to the trajectory's own steps.
    time_to_obs = (
        traj_ids.flip(-1).cumsum(-1).clamp_max(max_steps + 1).flip(-1) * traj_ids
    )
    time_to_obs = time_to_obs.sum(0)
    # Remove the current step from the count: 0 means "keep the regular next".
    time_to_obs = time_to_obs - 1
    return summed_rewards, time_to_obs


class MultiStep(nn.Module):
    """Multistep reward transform.

    Presented in

    | Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44.

    This module maps the "next" observation to the t + n "next" observation.
    It is an identity transform whenever :attr:`n_steps` is 0.

    Args:
        gamma (:obj:`float`): Discount factor for return computation
        n_steps (integer): maximum look-ahead steps.

    .. note:: This class is meant to be used within a ``DataCollector``.
        It will only treat the data passed to it at the end of a collection,
        and ignore data preceding that collection or coming in the next batch.
        As such, results on the last steps of the batch may likely be biased
        by the early truncation of the trajectory.
        To mitigate this effect, please use :class:`~torchrl.envs.transforms.MultiStepTransform`
        within the replay buffer instead.

    Examples:
        >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
        >>> from torchrl.data.postprocs import MultiStep
        >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter
        >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
        >>> env.set_seed(0)
        >>> collector = SyncDataCollector(env, policy=RandomPolicy(env.action_spec),
        ...     frames_per_batch=10, total_frames=2000, postproc=MultiStep(n_steps=4, gamma=0.99))
        >>> for data in collector:
        ...     break
        >>> print(data["step_count"])
        tensor([[0],
                [1],
                [2],
                [3],
                [4],
                [5],
                [6],
                [7],
                [8],
                [9]])
        >>> # the next step count is shifted by 3 steps in the future
        >>> print(data["next", "step_count"])
        tensor([[ 5],
                [ 6],
                [ 7],
                [ 8],
                [ 9],
                [10],
                [10],
                [10],
                [10],
                [10]])

    """

    def __init__(
        self,
        gamma: float,
        n_steps: int,
    ):
        super().__init__()
        if n_steps <= 0:
            raise ValueError("n_steps must be a non-negative integer.")
        if not (gamma > 0 and gamma <= 1):
            raise ValueError(f"got out-of-bounds gamma decay: gamma={gamma}")

        self.gamma = gamma
        self.n_steps = n_steps
        self.register_buffer(
            "gammas",
            torch.tensor(
                [gamma**i for i in range(n_steps + 1)],
                dtype=torch.float,
            ).reshape(1, 1, -1),
        )
        self.done_key = "done"
        self.done_keys = ("done", "terminated", "truncated")
        self.reward_keys = ("reward",)
        self.mask_key = ("collector", "mask")

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Re-writes a tensordict following the multi-step transform.

        Args:
            tensordict: :class:`tensordict.TensorDictBase` instance with
                ``[*Batch x Time-steps] shape.
                The TensorDict must contain a ``("next", "reward")`` and
                ``("next", "done")`` keys.
                All keys that are contained within the "next" nested tensordict
                will be shifted by (at most) :attr:`~.n_steps` frames.
                The TensorDict will also be updated with new key-value pairs:

                - gamma: indicating the discount to be used for the next
                  reward;
                - nonterminal: boolean value indicating whether a step is
                  non-terminal (not done or not last of trajectory);
                - original_reward: previous reward collected in the
                  environment (i.e. before multi-step);
                - The "reward" values will be replaced by the newly computed
                  rewards.

                The ``"done"`` key can have either the shape of the tensordict
                OR the shape of the tensordict followed by a singleton
                dimension OR the shape of the tensordict followed by other
                dimensions. In the latter case, the tensordict *must* be
                compatible with a reshape that follows the done shape (ie. the
                leading dimensions of every tensor it contains must match the
                shape of the ``"done"`` entry).
                The ``"reward"`` tensor can have either the shape of the
                tensordict (or done state) or this shape followed by a singleton
                dimension.

        Returns:
            in-place transformation of the input tensordict.

        """
        return _multi_step_func(
            tensordict,
            done_key=self.done_key,
            done_keys=self.done_keys,
            reward_keys=self.reward_keys,
            mask_key=self.mask_key,
            n_steps=self.n_steps,
            gamma=self.gamma,
        )


def _multi_step_func(
    tensordict,
    *,
    done_key,
    done_keys,
    reward_keys,
    mask_key,
    n_steps,
    gamma,
):
    """Functional core of the multi-step transform.

    Args:
        tensordict: data whose *last* batch dimension is time. It must contain
            ``("next", done_key)`` and ``("next", reward_key)`` for every key in
            ``reward_keys``.
        done_key: key (under ``"next"``) whose values delimit trajectories.
        done_keys: keys under ``"next"`` that are excluded from the time shift
            (they keep their original values).
        reward_keys: keys under ``"next"`` holding the rewards to be summed.
        mask_key: optional key of a mask; steps where it is ``False`` get
            ``nonterminal=False``. ``None`` disables masking.
        n_steps: number of rewards summed per step (``>= 1``).
        gamma: discount factor.

    Returns:
        A shallow copy (``clone(False)``) of ``tensordict`` where:

        - ``("next", reward_key)`` holds the discounted n-step reward sum;
        - ``("next", "original_reward")`` holds the reward before the transform;
        - the other ``"next"`` entries (rewards and ``done_keys`` excluded) are
          gathered from up to ``n_steps - 1`` steps ahead, within the same
          trajectory;
        - ``"steps_to_next_obs"``, ``"gamma"`` (``gamma ** steps_to_next_obs``)
          and ``"nonterminal"`` are added.

    NOTE(review): if ``reward_keys`` is empty, ``time_to_obs`` is never defined
    and a ``NameError`` is raised. With several reward keys, every original
    reward is renamed to the same ``("next", "original_reward")`` key — confirm
    this is intended.
    """
    # in accordance with common understanding of what n_steps should be
    # (n_steps counts summed rewards; _get_reward expects the number of *future* steps)
    n_steps = n_steps - 1
    # Shallow copy: the writes below do not alter the caller's tensordict structure.
    tensordict = tensordict.clone(False)
    done = tensordict.get(("next", done_key))

    # we'll be using the done states to index the tensordict.
    # if the shapes don't match we're in trouble.
    ndim = tensordict.ndim
    if done.shape != tensordict.shape:
        if done.shape[-1] == 1 and done.shape[:-1] == tensordict.shape:
            done = done.squeeze(-1)
        else:
            try:
                # let's try to reshape the tensordict
                # Extend the batch size to done's shape, then move the time dim
                # (originally at ndim - 1) to the last position.
                tensordict.batch_size = done.shape
                tensordict = tensordict.transpose(ndim - 1, tensordict.ndim - 1)
                done = tensordict.get(("next", done_key))
            except Exception as err:
                raise RuntimeError(
                    "tensordict shape must be compatible with the done's shape "
                    "(trailing singleton dimension excluded)."
                ) from err

    if mask_key is not None:
        mask = tensordict.get(mask_key, None)
    else:
        mask = None

    *batch, T = tensordict.batch_size

    summed_rewards = []
    for reward_key in reward_keys:
        reward = tensordict.get(("next", reward_key))

        # sum rewards
        # time_to_obs depends only on ``done``, so the last iteration's value is kept.
        summed_reward, time_to_obs = _get_reward(gamma, reward, done, n_steps)
        summed_rewards.append(summed_reward)

    idx_to_gather = torch.arange(
        T, device=time_to_obs.device, dtype=time_to_obs.dtype
    ).expand(*batch, T)
    idx_to_gather = idx_to_gather + time_to_obs

    # idx_to_gather looks like  tensor([[ 2,  3,  4,  5,  5,  5,  8,  9, 10, 10, 10]])
    # with a done state         tensor([[ 0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  1]])
    # (here n_steps=3 before the decrement above)
    # meaning that the first obs will be replaced by the third, the second by the fourth etc.
    # Entry 5 (0-based) maps to itself as it is terminal; entry 4 can only reach entry 5.
    tensordict_gather = (
        tensordict.get("next")
        .exclude(*reward_keys, *done_keys)
        .gather(-1, idx_to_gather)
    )

    tensordict.set("steps_to_next_obs", time_to_obs + 1)
    for reward_key, summed_reward in zip(reward_keys, summed_rewards):
        tensordict.rename_key_(("next", reward_key), ("next", "original_reward"))
        tensordict.set(("next", reward_key), summed_reward)

    tensordict.get("next").update(tensordict_gather)
    tensordict.set("gamma", gamma ** (time_to_obs + 1))
    # The last step of each trajectory (including the one truncated by the end
    # of the batch) has time_to_obs == 0 and is flagged as terminal.
    nonterminal = time_to_obs != 0
    if mask is not None:
        mask = mask.view(*batch, T)
        nonterminal[~mask] = False
    tensordict.set("nonterminal", nonterminal)
    if tensordict.ndim != ndim:
        # Undo the reshape above: move time back to position ndim - 1 and
        # restore the original number of batch dimensions.
        tensordict = tensordict.apply(
            lambda x: x.transpose(ndim - 1, tensordict.ndim - 1),
            batch_size=done.transpose(ndim - 1, tensordict.ndim - 1).shape,
        )
        tensordict.batch_size = tensordict.batch_size[:ndim]
    return tensordict


class DensifyReward(TensorDictModuleBase):
    """A util to reassign the reward at done state to the rest of the trajectory.

    This transform is to be used with sparse rewards to assign a reward to each step of a trajectory when only the
    reward at `done` is non-null.

    .. note:: The class calls the :func:`~torchrl.objectives.value.functional.reward2go` function, which will
        also sum intermediate rewards. Make sure you understand what the `reward2go` function returns before using
        this module.

    Args:
        reward_key (NestedKey, optional): The key in the input TensorDict where the reward is stored.
            Defaults to `"reward"`.
        done_key (NestedKey, optional): The key in the input TensorDict where the done flag is stored.
            Defaults to `"done"`.
        reward_key_out (NestedKey | None, optional): The key in the output TensorDict where the reassigned reward
            will be stored. If None, it defaults to the value of `reward_key`.
            Defaults to `None`.
        time_dim (int, optional): The dimension of the reward and done tensors along which time is unrolled.
            It is forwarded to `reward2go`. Defaults to `-2`, i.e. the tensors are expected to carry a trailing
            feature dimension (e.g. shape `[*batch, T, 1]`).
        discount (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards.
            Defaults to `1.0` (no discounting).

    Returns:
        TensorDict: The input TensorDict with the reassigned reward stored under the key specified by `reward_key_out`.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>>
        >>> from torchrl.data import DensifyReward
        >>>
        >>> # Create a sample TensorDict
        >>> tensordict = TensorDict({
        ...     "next": {
        ...         "reward": torch.zeros(10, 1),
        ...         "done": torch.zeros(10, 1, dtype=torch.bool)
        ...     }
        ... }, batch_size=[10])
        >>> # Set some done flags and rewards
        >>> tensordict["next", "done"][[3, 7]] = True
        >>> tensordict["next", "reward"][3] = 3
        >>> tensordict["next", "reward"][7] = 7
        >>> # Create an instance of DensifyReward
        >>> densify_reward = DensifyReward()
        >>> # Apply the transform
        >>> new_tensordict = densify_reward(tensordict)
        >>> # Print the reassigned rewards
        >>> print(new_tensordict["next", "reward"])
        tensor([[3.],
                [3.],
                [3.],
                [3.],
                [7.],
                [7.],
                [7.],
                [7.],
                [0.],
                [0.]])

    """

    def __init__(
        self,
        *,
        reward_key: NestedKey = "reward",
        done_key: NestedKey = "done",
        reward_key_out: NestedKey | None = None,
        time_dim: int = -2,
        discount: float = 1.0,
    ):
        # Local import (kept from the original) so torchrl.objectives is only
        # loaded when this module is instantiated.
        from torchrl.objectives.value.functional import reward2go

        super().__init__()
        # in_keys[0] is the reward key, in_keys[1] the done key (both read under "next").
        self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
        if reward_key_out is None:
            reward_key_out = reward_key
        self.out_keys = [unravel_key(reward_key_out)]
        self.time_dim = time_dim
        self.discount = discount
        self.reward2go = reward2go

    def forward(self, tensordict):
        """Replaces the reward with its (discounted) reward-to-go.

        Args:
            tensordict: data holding ``("next", reward_key)`` and
                ``("next", done_key)`` entries of identical shape.

        Returns:
            the input tensordict, updated in place with
            ``("next", reward_key_out)``.

        Raises:
            RuntimeError: if the reward and done tensors have different shapes.
        """
        # Get done
        done = tensordict.get(("next", self.in_keys[1]))
        # Get reward
        reward = tensordict.get(("next", self.in_keys[0]))
        if reward.shape != done.shape:
            raise RuntimeError(
                f"reward and done state are expected to have the same shape. Got reward.shape={reward.shape} "
                f"and done.shape={done.shape}."
            )
        # Honour the configured time dimension (previously hard-coded to -2,
        # which is still the default).
        reward = self.reward2go(
            reward, done, time_dim=self.time_dim, gamma=self.discount
        )
        tensordict.set(("next", self.out_keys[0]), reward)
        return tensordict
